import torch
import torch.nn as nn
from models import *
def print_torch_model_params(model: nn.Module):
    """Print a header line followed by one ``name: shape`` line per parameter.

    Parameters are listed in the order yielded by ``model.named_parameters()``.
    Each shape is rendered as a plain Python tuple (e.g. ``(2, 3)``).

    Args:
        model: The module whose parameters are listed.
    """
    # The header reads "PyTorch model parameters"; it is program output, so it is kept verbatim.
    print("===== PyTorch 模型参数 =====")
    # A lazy generator keeps the print interleaving identical to a plain loop:
    # each line is emitted as soon as its parameter is yielded.
    rows = (
        f"{label}: {tuple(weight.shape)}"
        for label, weight in model.named_parameters()
    )
    for row in rows:
        print(row)
# Script entry point: build the model and dump its parameter names and shapes.
# Guarded so that importing this module does not construct a model or print
# anything as a side effect; running the file directly behaves as before.
if __name__ == "__main__":
    # NOTE(review): MVSNet comes from the `models` star import; what
    # refine=False disables is defined there. Confirm against `models`.
    model = MVSNet(refine=False)
    print_torch_model_params(model)
